import datetime
import json
import logging
import os
import uuid

import numpy as np
import pyedflib

import edf_repair_unidecode
import oss_util

logger = logging.getLogger(__name__)


def calc_record_size(edf):
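    """Return the size in bytes of one EDF Data Record.

    EDF stores each sample as a 16-bit integer, and this script assumes
    1-second Data Records, so each channel contributes
    sample_frequency * 2 bytes. For example, 19 channels sampled at
    256 Hz give 19 * 256 * 2 = 9728 bytes per record.
    """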
    n_channels = edf.signals_in_file
    sample_frequencies = edf.getSampleFrequencies()
    channel_sizes = []
    for i in range(n_channels):
        freq = sample_frequencies[i]
        # Each sample is stored as a 16-bit integer (2 bytes), so with the
        # assumed 1-second Data Records a channel contributes freq * 2 bytes per record
        channel_size = freq * 2
        channel_sizes.append(channel_size)
    # The size of one Data Record is the sum over all channels
    record_size = sum(channel_sizes)
    return record_size


def split_edf(s3_client, bucket_name, base_url, prefix, edf_file, split_size, filename_prefix):
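    """Split edf_file into consecutive EDF+ parts of roughly split_size MB.

    Each part preserves the original signal headers and patient metadata,
    is uploaded to object storage under prefix/filename_prefix, and adds
    one entry to the returned infos list. Returns (split_seconds, infos),
    where split_seconds is the nominal duration of one part in seconds.
    """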
    # Read the EDF file
    edf = pyedflib.EdfReader(edf_file)
    n_channels = edf.signals_in_file
    signal_headers = edf.getSignalHeaders()

    # Calculate the size of one Data Record
    record_size = calc_record_size(edf)

    # Convert split size from MB to bytes
    split_size_bytes = split_size * 1024 * 1024
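
    # For scale, a hypothetical 19-channel recording at 256 Hz has a
    # record_size of 19 * 256 * 2 = 9728 bytes, so a 10 MB split covers
    # about 10 * 1024 * 1024 / 9728 ≈ 1078 seconds of data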

    # Duration of one Data Record in seconds (EDF records typically hold 1 second of data)
    record_duration = int(edf.datarecord_duration)

    # Number of seconds of data that fit into one split of split_size MB
    split_seconds = int(np.ceil(split_size_bytes / record_size))

    # Number of blocks to split the file into
    duration = edf.getFileDuration() / record_duration  # total number of Data Records
    n_blocks = int(np.ceil(duration / split_seconds))

    infos = []

    # Process each block
    for i in range(n_blocks):
        logger.info(f'Writing block {i + 1} of {n_blocks}')

        # Filename for this part
        split_file = filename_prefix + f'_{split_size}_part{i + 1}.edf'

        # Calculate the start time for the new EDF file
        start_time = i * split_seconds
        start_datetime = edf.getStartdatetime() + datetime.timedelta(seconds=start_time)

        writer = pyedflib.EdfWriter(split_file, n_channels, file_type=pyedflib.FILETYPE_EDFPLUS)

        # Set the start time for the new EDF file
        writer.setStartdatetime(start_datetime)

        # Set the signal headers and patient information for the new EDF file
        writer.setSignalHeaders(signal_headers)
        writer.setPatientName(edf.getPatientName())
        writer.setPatientCode(edf.getPatientCode())
        writer.setGender(edf.getGender())
        # Older pyedflib releases take the record duration in units of
        # 10 microseconds, so 1 second is passed as 100000
        writer.setDatarecordDuration(record_duration * 100000)

        try:
            birthdate = edf.getBirthdate()
            if birthdate:
                writer.setBirthdate(birthdate)
        except Exception:
            # Some files expose the birthdate only as a datetime object
            birthdate = edf.getBirthdate(string=False)
            if birthdate:
                writer.setBirthdate(birthdate)

        # End time of the current block, clipped to the end of the file
        end_time = min((i + 1) * split_seconds, duration)

        # Number of 1-second Data Records in this block
        n_records = int(np.ceil(end_time - start_time))
        # Get the sampling frequencies for each channel
        sample_frequencies = edf.getSampleFrequencies()
        # Calculate the start sample points for each channel
        start_samples = [int(start_time * sf) for sf in sample_frequencies]

        # Write EDF signal data to the new EDF file
        for n in range(n_records):
            for j in range(n_channels):
                start_sample = start_samples[j] + n * sample_frequencies[j]
                samples = edf.readSignal(j, start=int(start_sample), n=int(sample_frequencies[j]))
                writer.writePhysicalSamples(samples)

        # Close the new EDF file
        writer.close()

        if prefix == '':
            oss_key = filename_prefix + '/' + split_file
        else:
            oss_key = prefix + '/' + filename_prefix + '/' + split_file
        oss_util.upload_file(s3_client, bucket_name, oss_key, split_file)
        infos.append({
            "filePath": base_url + '/' + oss_key,
            "duration": n_records,
        })

    # Close the original EDF file
    edf.close()

    return split_seconds, infos


def get_index_file(oss_properties, edf_file, split_size):
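    """Split edf_file, upload the parts, and publish a JSON index file.

    The index loosely mirrors an HLS playlist: targetDuration is the
    length of one part in seconds and infos lists each part's URL and
    duration in playback order. Returns a dict holding the index file
    URL and the original file's basename.
    """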
    s3_client, bucket_name, base_url, prefix = oss_util.get_bucket(oss_properties)
    # Random prefix so repeated uploads of the same file do not collide
    filename_prefix = str(uuid.uuid4()).replace('-', '')
    index_filename = filename_prefix + f'_{split_size}_index.json'
    if prefix == '':
        oss_key = filename_prefix + '/' + index_filename
    else:
        oss_key = prefix + '/' + filename_prefix + '/' + index_filename

    split_seconds, infos = split_edf(s3_client, bucket_name, base_url, prefix, edf_file, split_size, filename_prefix)
    index_data = {
        "tag": "LY-FILE",
        "version": "1",
        "allowCache": True,
        "endList": True,
        "targetDuration": split_seconds,
        "infos": infos
    }
    # Write data to a JSON file
    with open(index_filename, "w") as f:
        json.dump(index_data, f, indent=4)

    oss_util.upload_file(s3_client, bucket_name, oss_key, index_filename)

    result_data = {
        "indexFileUrl": base_url + '/' + oss_key,
        "originalFileUrl": os.path.basename(edf_file),
    }

    logger.info(result_data)

    return result_data


def check_file_ok(file_path):
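    """Return True if file_path parses as a valid EDF file.

    On failure, repair the header with edf_repair_unidecode, replace the
    original file in place with the repaired copy, and return False.
    """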
    try:
        edf = pyedflib.EdfReader(file_path)
        edf.close()  # the file only needed to parse; release the handle
        flag = True
    except Exception:
        directory, filename = os.path.split(file_path)
        output_file_path = os.path.join(directory, "repair_" + filename)
        edf_repair_unidecode.repair_edf_header(file_path, output_file_path)
        logger.info(f'Replacing {file_path} with repaired copy {output_file_path}')
        os.remove(file_path)
        # Rename the repaired copy back to the original path
        os.rename(output_file_path, file_path)
        flag = False
    return flag


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    # path of the EDF file to split
    file_path = 'demo.edf'
    # target size of each split part, in MB
    split_size = 10
    # oss param
    accessKey = ''
    accessPolicy = ''
    bucketName = ''
    domain = ''
    endpoint = ''
    isHttps = ''
    prefix = ''
    region = ''
    secretKey = ''
    ossProperties = {"accessKey": "minioAssessKey", "accessPolicy": "1", "bucketName": "edfs",
                     "domain": "127.0.0.1:9000", "endpoint": "127.0.0.1:9000", "isHttps": "N",
                     "prefix": "", "region": "", "secretKey": "minioSecretKey"}

    # repair the header if needed, then split and upload
    check_file_ok(file_path=file_path)
    result_data = get_index_file(ossProperties, file_path, split_size)
    print(result_data)
